#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat May 14 23:34:11 2022

Simulation of sliced 1-Wasserstein distance
"""

import ot
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import gaussian_kde
from fractions import Fraction

def generate_uniform_sphere(d, n, R):
    """Draw ``n`` points on the sphere of radius ``R`` in dimension ``d``.

    Each point is a standard Gaussian vector rescaled to have Euclidean
    norm ``R``; by rotational invariance of the isotropic Gaussian the
    resulting direction is uniform on the sphere.

    Parameters
    ----------
    d : int
        Dimension of each point.
    n : int
        Number of points to draw.
    R : float
        Radius the points are rescaled to.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n, d)``, one point per row.

    Notes
    -----
    Randomness comes from the global ``np.random`` state.
    """
    # One vectorized draw replaces n separate (1, d) draws.
    gauss = np.random.normal(0, 1, size=(n, d))
    # keepdims gives shape (n, 1) so each row is divided by its own norm.
    norms = np.linalg.norm(gauss, axis=1, keepdims=True)
    return R * gauss / norms


# Compute the covariance of the limiting Gaussian process
def cov_simulate(d, m):
    """Monte Carlo estimate of the limiting Gaussian covariance.

    Draws ``m`` unit directions ``u_1, ..., u_m`` in dimension ``d`` and
    averages ``(2/3) * s_i * s_j * <u_i, u_j>`` over all ``m**2`` ordered
    pairs (the pairs ``i == j`` included), where
    ``s_i = sign(sum of the coordinates of u_i)``.

    Parameters
    ----------
    d : int
        Dimension of the directions.
    m : int
        Number of directions drawn.

    Returns
    -------
    float
        The averaged value; ``0`` when ``m == 0``.
    """
    angles = generate_uniform_sphere(d, m, 1)
    if m == 0:
        # No pairs to average over; avoid dividing by zero below.
        return 0
    signs = np.sign(angles.sum(axis=1))
    # sum_{i,j} s_i s_j <u_i, u_j> = || sum_i s_i u_i ||^2, so the double
    # loop over pairs collapses to a single signed sum of the directions.
    signed_total = np.dot(signs, angles)
    return 2 * np.dot(signed_total, signed_total) / 3 / m**2



# ---- Simulation parameters -------------------------------------------------
R = 1          # radius passed to generate_uniform_sphere
d = 3          # dimension of the sample points
# Centring constant subtracted from the empirical distances.
# NOTE(review): presumably the population sliced 1-Wasserstein distance
# between the two sampled laws -- confirm where 0.867 comes from.
rswd = 0.867
vaS = cov_simulate(d, 1000)   # variance used for the limiting Gaussian density

n = 1000       # size of each sample
B = 500        # bootstrap replicates per replacement exponent
replacement = [1, 7/8, 3/4, 1/2]   # exponents li; resample size is int(n**li)
n_rep = 500    # Monte Carlo replicates of the empirical statistic

m = 1          # 1-based figure / row counter used by the bootstrap section
xs = np.linspace(-2, 2, 200)       # grid on which all densities are drawn
# Density of N(0, vaS) evaluated on the grid.
limSdens = np.exp(-xs**2/(2*vaS))/np.sqrt(2*vaS*np.pi)

n_seed = 10    # projection seeds averaged per distance evaluation
# Buffer shapes follow the parameters above so that editing `replacement`,
# `B` or `n_rep` cannot desynchronise them from the loops that fill them.
swd = np.empty((n_rep,))
boot = np.empty((len(replacement), B))
# Uniform weights on the n points of each sample.
a, b = np.ones((n,)) / n, np.ones((n,)) / n

# Generate samples and compute the centred, sqrt(n)-scaled sliced
# 1-Wasserstein distance between their empirical distributions.
# The replicate count is taken from the buffer so the two cannot diverge.
for i in range(len(swd)):
    datap = generate_uniform_sphere(d, n, R)
    # Second sample: same construction, shifted by +1 in every coordinate.
    dataq = generate_uniform_sphere(d, n, R) + 1
    # Average the distance over n_seed projection seeds (1000 projections each).
    smp = np.array([
        ot.sliced.sliced_wasserstein_distance(datap, dataq, a, b, 1000, p=1, seed=seed)
        for seed in range(n_seed)
    ])
    swd[i] = np.sqrt(n) * (np.mean(smp) - rswd)
# Kernel density estimate of the replicates (Silverman bandwidth).
densswd = gaussian_kde(swd, 'silverman')

# Perform the bootstrap with various replacement sizes l = int(n**li) and
# plot, one figure per exponent, the bootstrap density against the Monte
# Carlo density and the limiting Gaussian density.
swd_curve = densswd(xs)   # identical in every figure, so evaluate it once
for m, li in enumerate(replacement, start=1):
    boot_size = int(np.power(n, li))
    # Uniform weights on the resampled points.
    al = np.ones((boot_size,)) / boot_size
    bl = np.ones((boot_size,)) / boot_size
    datap = generate_uniform_sphere(d, n, R)
    dataq = generate_uniform_sphere(d, n, R) + 1

    # Distance between the full samples, averaged over the projection seeds;
    # this is the centring value for the bootstrap replicates.  The buffer is
    # allocated here rather than reused from the earlier sampling loop.
    smp = np.empty((n_seed,))
    for seed in range(n_seed):
        smp[seed] = ot.sliced.sliced_wasserstein_distance(datap, dataq, a, b, 1000, p=1, seed=seed)
    swd_boot = np.mean(smp)

    btp = np.empty((n_seed,))
    for j in range(B):
        # NOTE(review): the same indices resample both samples (paired
        # resampling) -- confirm this is intended rather than independent
        # draws for each sample.
        indices = np.random.choice(n, boot_size, replace=True)
        rep = datap[indices]
        req = dataq[indices]
        for seed in range(n_seed):
            btp[seed] = ot.sliced.sliced_wasserstein_distance(rep, req, al, bl, 1000, p=1, seed=seed)
        boot[m-1, j] = np.sqrt(boot_size) * (np.mean(btp) - swd_boot)

    densboot = gaussian_kde(boot[m-1, :], 'silverman')
    boot_curve = densboot(xs)
    plt.figure(m)
    plt.plot(xs, swd_curve, color='cadetblue')
    plt.fill_between(xs, swd_curve, color='paleturquoise', alpha=0.3)
    plt.plot(xs, boot_curve, color='darkolivegreen')
    plt.fill_between(xs, boot_curve, color='palegreen', alpha=0.4)
    plt.plot(xs, limSdens, color='palevioletred')
    plt.fill_between(xs, limSdens, color='pink', alpha=0.3)
    plt.xlabel("x")
    plt.ylabel("Density")
    plt.title('replacement l = n^'+str(Fraction(li)))




